619ff6
@@ -28,8 +28,6 @@
 import java.util.Set;
 import java.util.Stack;
 
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
 import org.apache.hadoop.hive.ql.exec.JoinOperator;
 import org.apache.hadoop.hive.ql.exec.Operator;
 import org.apache.hadoop.hive.ql.exec.OperatorFactory;
@@ -49,12 +47,12 @@
 import org.apache.hadoop.hive.ql.lib.RuleRegExp;
 import org.apache.hadoop.hive.ql.metadata.Table;
 import org.apache.hadoop.hive.ql.parse.ParseContext;
-import org.apache.hadoop.hive.ql.parse.QBJoinTree;
 import org.apache.hadoop.hive.ql.parse.SemanticException;
 import org.apache.hadoop.hive.ql.plan.ExprNodeColumnDesc;
 import org.apache.hadoop.hive.ql.plan.ExprNodeConstantDesc;
 import org.apache.hadoop.hive.ql.plan.ExprNodeDesc;
 import org.apache.hadoop.hive.ql.plan.ExprNodeDesc.ExprNodeDescEqualityWrapper;
+import org.apache.hadoop.hive.ql.plan.ExprNodeDescUtils;
 import org.apache.hadoop.hive.ql.plan.ExprNodeGenericFuncDesc;
 import org.apache.hadoop.hive.ql.plan.FilterDesc;
 import org.apache.hadoop.hive.ql.plan.OperatorDesc;
@@ -70,6 +68,8 @@
 import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorConverters.Converter;
 import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoFactory;
 import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoUtils;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
 
 /**
  * SkewJoinOptimizer.
@@ -283,10 +283,11 @@
private boolean getTableScanOps(
      * @param op The join operator being optimized
      * @param tableScanOpsForJoin table scan operators which are parents of the join operator
      * @return map<join keys intersection skewedkeys, list of skewed values>.
+     * @throws SemanticException if a join key column cannot be backtracked to its table scan
      */
     private Map<List<ExprNodeDesc>, List<List<String>>>
       getSkewedValues(
-        Operator<? extends OperatorDesc> op, List<TableScanOperator> tableScanOpsForJoin) {
+        Operator<? extends OperatorDesc> op, List<TableScanOperator> tableScanOpsForJoin) throws SemanticException {
 
       Map <List<ExprNodeDesc>, List<List<String>>> skewDataReturn =
         new HashMap<List<ExprNodeDesc>, List<List<String>>>();
@@ -299,6 +300,7 @@
private boolean getTableScanOps(
         ReduceSinkDesc rsDesc = ((ReduceSinkOperator) reduceSinkOp).getConf();
 
         if (rsDesc.getKeyCols() != null) {
+          TableScanOperator tableScanOp = null;
           Table table = null;
           // Find the skew information corresponding to the table
           List<String> skewedColumns = null;
@@ -321,7 +323,9 @@
private boolean getTableScanOps(
             if (keyColDesc instanceof ExprNodeColumnDesc) {
               keyCol = (ExprNodeColumnDesc) keyColDesc;
               if (table == null) {
-                table = getTable(parseContext, reduceSinkOp, tableScanOpsForJoin);
+                tableScanOp = getTableScanOperator(parseContext, reduceSinkOp, tableScanOpsForJoin);
+                table =
+                  tableScanOp == null ? null : tableScanOp.getConf().getTableMetadata();
                 skewedColumns =
                   table == null ? null : table.getSkewedColNames();
                 // No skew on the table to take care of
@@ -332,10 +336,13 @@
private boolean getTableScanOps(
                 skewedValueList =
                   table == null ? null : table.getSkewedColValues();
               }
-              int pos = skewedColumns.indexOf(keyCol.getColumn());
+              ExprNodeDesc keyColOrigin = ExprNodeDescUtils.backtrack(keyCol,
+                      reduceSinkOp, tableScanOp);
+              int pos = keyColOrigin == null || !(keyColOrigin instanceof ExprNodeColumnDesc) ?
+                      -1 : skewedColumns.indexOf(((ExprNodeColumnDesc)keyColOrigin).getColumn());
               if ((pos >= 0) && (!positionSkewedKeys.contains(pos))) {
                 positionSkewedKeys.add(pos);
-                ExprNodeColumnDesc keyColClone = (ExprNodeColumnDesc) keyCol.clone();
+                ExprNodeColumnDesc keyColClone = (ExprNodeColumnDesc) keyColOrigin.clone();
                 keyColClone.setTabAlias(null);
                 joinKeysSkewedCols.add(new ExprNodeDescEqualityWrapper(keyColClone));
               }
@@ -386,9 +393,9 @@
private boolean getTableScanOps(
     }
 
     /**
-     * Get the table alias from the candidate table scans.
+     * Get the table scan operator for the given operator, if it is one of the candidate table scans.
      */
-    private Table getTable(
+    private TableScanOperator getTableScanOperator(
       ParseContext parseContext,
       Operator<? extends OperatorDesc> op,
       List<TableScanOperator> tableScanOpsForJoin) {
@@ -396,7 +403,7 @@
private Table getTable(
         if (op instanceof TableScanOperator) {
           TableScanOperator tsOp = (TableScanOperator)op;
           if (tableScanOpsForJoin.contains(tsOp)) {
-            return tsOp.getConf().getTableMetadata();
+            return tsOp;
           }
         }
         if ((op.getParentOperators() == null) || (op.getParentOperators().isEmpty()) || 
